"""
Phase 6: Robustness Tests
Development Threshold Paper

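Alternative thresholds, exclusion of oil exporters, period splits, regional
controls, a reverse-causality check, and panel CA regressions with alternative
zone definitions. (An MER-GDP variant is not implemented in the code visible here.)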
"""

import sys
from pathlib import Path

import numpy as np
import pandas as pd
import statsmodels.api as sm
from scipy import stats

sys.path.insert(0, '/mnt/c/demographics_capital_flows/multilateral/src')
from model import PanelGLS

# ═══════════════════════════════════════════════════════════════════════
# Load data
# ═══════════════════════════════════════════════════════════════════════
# Panel data; keep observations up to and including 2024.
full = pd.read_csv('/mnt/c/demographics_capital_flows/multilateral/140_country/data/processed/full_panel.csv')
panel = full[full['year'] <= 2024].copy()

# Resource rents are optional: section 6b uses them to flag oil/commodity
# exporters. If the file is missing, unparsable, or lacks the expected columns,
# fall back to an all-NaN column (so nothing gets flagged) -- but report it,
# and do not swallow unrelated exceptions such as KeyboardInterrupt.
try:
    rents = pd.read_csv('/mnt/c/demographics_capital_flows/multilateral/140_country/data/raw/wdi_resource_rents.csv')
    panel = panel.merge(rents[['iso3', 'year', 'resource_rents_gdp']], on=['iso3', 'year'], how='left')
except (OSError, KeyError, ValueError) as exc:
    # OSError: file missing/unreadable; KeyError: expected columns absent;
    # ValueError: pandas parser/merge errors (EmptyDataError, ParserError, MergeError).
    print(f"WARNING: resource rents not loaded ({type(exc).__name__}: {exc}); "
          "no oil/commodity exporters can be identified")
    panel['resource_rents_gdp'] = np.nan

# One row per country with zone-crossing information (used from 6b onwards).
# NOTE(review): one-row-per-country is assumed from how the file is used below -- confirm.
cross_df = pd.read_csv('/mnt/c/demographics_capital_flows/development_threshold/data/crossing_data.csv')

# ═══════════════════════════════════════════════════════════════════════
# 6a. Alternative thresholds
# ═══════════════════════════════════════════════════════════════════════
print("=" * 70)
print("6a. ALTERNATIVE THRESHOLDS")
print("=" * 70)

# (label, lower bound, upper bound) of the zone, compared against
# panel['gdp_pc_ppp'].
threshold_sets = [
    ('$7k-$20k', 7000, 20000),
    ('$9k-$25k (baseline)', 9000, 25000),
    ('$10k-$30k', 10000, 30000),
    ('$12k-$30k', 12000, 30000),
    ('$8k-$22k', 8000, 22000),
]

robustness_results = []

# --- Threshold-independent preparation (hoisted out of the loop) -----------
# Sort by year so that 'first'/'last' below really are the earliest/latest
# GDP observations, whatever the row order of the input CSV.
gdp_obs = panel[panel['gdp_pc_ppp'].notna()].sort_values('year')
countries = gdp_obs.groupby('iso3').agg(
    first_gdp=('gdp_pc_ppp', 'first'),
    last_gdp=('gdp_pc_ppp', 'last'),
)

# Year-ordered GDP history per country; countries with fewer than 5 GDP
# observations are left out of the entrant sample.
gdp_history = {
    iso: obs
    for iso, obs in gdp_obs.groupby('iso3', sort=False)
    if len(obs) >= 5
}

for label, lower, upper in threshold_sets:
    # Classify countries by their first and last observed GDP
    crossed = ((countries['first_gdp'] < lower) & (countries['last_gdp'] > upper)).sum()
    in_zone = ((countries['last_gdp'] >= lower) & (countries['last_gdp'] <= upper)).sum()

    # Quick logit: does Z₁ predict crossing?
    # An "entrant" is any country ever observed at or above `lower`; its entry
    # row is the earliest such observation.
    # NOTE(review): a country already above `lower` at its first observation
    # also counts as an entrant here, whereas 6b drops 'Always above'
    # countries -- confirm this asymmetry is intended.
    zone_entrants = []
    for iso, obs in gdp_history.items():
        above_lower = obs[obs['gdp_pc_ppp'] >= lower]
        if above_lower.empty:
            continue
        zone_entrants.append({
            'iso3': iso,
            'exited_above': int((obs['gdp_pc_ppp'] >= upper).any()),
            'Z_1': above_lower['Z_1'].iloc[0],
            'gdp_pc_ppp': above_lower['gdp_pc_ppp'].iloc[0],
        })

    # Explicit columns keep the frame well-formed even with zero entrants.
    ze = pd.DataFrame(zone_entrants, columns=['iso3', 'exited_above', 'Z_1', 'gdp_pc_ppp'])
    ze_clean = ze[['exited_above', 'Z_1', 'gdp_pc_ppp']].dropna()

    # NaN marks "logit not estimated" (sample too small or estimation failed).
    z1_p = np.nan
    z1_coef = np.nan
    pr2 = np.nan
    if len(ze_clean) >= 20 and ze_clean['exited_above'].sum() >= 3:
        try:
            model = sm.Logit(ze_clean['exited_above'],
                             sm.add_constant(ze_clean[['Z_1', 'gdp_pc_ppp']])).fit(disp=0)
        except Exception as exc:
            # Best effort: keep going with the other threshold sets, but say
            # why this one failed (e.g. perfect separation, singular matrix).
            print(f"\n  {label}: logit raised {type(exc).__name__}: {exc}")
        else:
            z1_p = model.pvalues['Z_1']
            z1_coef = model.params['Z_1']
            pr2 = model.prsquared

    # Comparisons with NaN are False, so a missing p-value gets no stars.
    sig = '***' if z1_p < 0.01 else '**' if z1_p < 0.05 else '*' if z1_p < 0.1 else ''
    if pd.notna(z1_coef):
        print(f"\n  {label}: crossed={crossed}, in_zone={in_zone}, "
              f"Z₁ coef={z1_coef:.3f}{sig}, p={z1_p:.4f}, R²={pr2:.3f}")
    else:
        print(f"\n  {label}: crossed={crossed}, in_zone={in_zone}, logit failed")

    robustness_results.append({
        'threshold': label, 'lower': lower, 'upper': upper,
        'n_crossed': crossed, 'n_in_zone': in_zone,
        'z1_coef': z1_coef, 'z1_pvalue': z1_p, 'pseudo_r2': pr2,
        'n_entrants': len(ze_clean)
    })

# ═══════════════════════════════════════════════════════════════════════
# 6b. Exclude oil exporters
# ═══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("6b. EXCLUDE OIL EXPORTERS")
print("=" * 70)

# Identify oil exporters (resource rents >= 10% at any point)
# If the rents column is entirely missing, nothing can be flagged and the
# "exclusion" below would silently reproduce the unrestricted sample.
if not panel['resource_rents_gdp'].notna().any():
    print("WARNING: no resource-rent data available; no country can be excluded, "
          "so this section uses the unrestricted sample")
oil_exporters = panel[panel['resource_rents_gdp'] >= 10]['iso3'].unique()
print(f"Oil/commodity exporters excluded: {len(oil_exporters)}")

# Historical zone entrants: rows with a recorded 'cross_9k' whose status is
# not 'Always above'. This frame is reused by sections 6c-6e.
zone_entrants_hist = cross_df[cross_df['cross_9k'].notna() & (cross_df['status'] != 'Always above')].copy()
ze_no_oil = zone_entrants_hist[~zone_entrants_hist['iso3'].isin(oil_exporters)]

logit_vars = ['Z_1_entry', 'gdp_pc_ppp_entry']
ze_sample = ze_no_oil[['exited_above'] + logit_vars].dropna()

if len(ze_sample) >= 15:
    try:
        model_no_oil = sm.Logit(ze_sample['exited_above'],
                                sm.add_constant(ze_sample[logit_vars])).fit(disp=0)
    except Exception as exc:
        # Do not let one failed fit (e.g. perfect separation) abort the
        # remaining robustness sections.
        print(f"\nLogit excluding oil exporters failed (n={len(ze_sample)}): "
              f"{type(exc).__name__}: {exc}")
    else:
        print(f"\nLogit excluding oil exporters (n={len(ze_sample)}):")
        for var in logit_vars:
            p_val = model_no_oil.pvalues[var]
            sig = '***' if p_val < 0.01 else '**' if p_val < 0.05 else '*' if p_val < 0.1 else ''
            print(f"  {var}: coef={model_no_oil.params[var]:.4f}, p={p_val:.4f}{sig}")
        print(f"  pseudo-R²={model_no_oil.prsquared:.3f}")

    # How many crossers were oil exporters?
    oil_crossers = zone_entrants_hist[(zone_entrants_hist['exited_above'] == 1) &
                                      (zone_entrants_hist['iso3'].isin(oil_exporters))]
    print(f"\n  Oil exporters that crossed: {len(oil_crossers)}")
    for _, row in oil_crossers.iterrows():
        print(f"    {row['iso3']}")
else:
    print(f"\nLogit excluding oil exporters skipped: only {len(ze_sample)} "
          "complete observations (need at least 15)")

# ═══════════════════════════════════════════════════════════════════════
# 6c. Period splits: pre-2000 vs post-2000
# ═══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("6c. PERIOD SPLITS")
print("=" * 70)

# (label, first entry year, last entry year); both year bounds are inclusive
# and are matched against the 'cross_9k' column.
entry_periods = [
    ('Pre-2000 entrants', 1950, 1999),
    ('Post-2000 entrants', 2000, 2024),
]

for period_label, min_year, max_year in entry_periods:
    entered_in_period = zone_entrants_hist['cross_9k'].between(min_year, max_year)
    sub = zone_entrants_hist.loc[entered_in_period]
    n_total = len(sub)
    n_crossed = sub['exited_above'].sum()

    print(f"\n  {period_label}: {n_total} entrants, {n_crossed} crossed")

    # Complete cases only; the comparison needs at least 10 of them and at
    # least 2 crossers.
    ze_sample = sub[['exited_above', 'Z_1_entry', 'gdp_pc_ppp_entry']].dropna()
    if len(ze_sample) < 10 or ze_sample['exited_above'].sum() < 2:
        continue

    # Bivariate Z₁ comparison (Welch t-test: unequal variances)
    crosser_z1 = ze_sample.loc[ze_sample['exited_above'] == 1, 'Z_1_entry']
    stayer_z1 = ze_sample.loc[ze_sample['exited_above'] == 0, 'Z_1_entry']
    if min(len(crosser_z1), len(stayer_z1)) < 2:
        continue

    t, p = stats.ttest_ind(crosser_z1, stayer_z1, equal_var=False)
    print(f"    Z₁ crossers={crosser_z1.mean():.3f}, non-crossers={stayer_z1.mean():.3f}, p={p:.4f}")

# ═══════════════════════════════════════════════════════════════════════
# 6d. Regional controls
# ═══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("6d. REGIONAL CONTROLS: IS THIS JUST EAST ASIA?")
print("=" * 70)

# Assign regions (ISO3 codes not listed here fall into 'Other')
region_map = {
    'East Asia': ['CHN', 'KOR', 'TWN', 'HKG', 'SGP', 'JPN', 'MYS', 'THA', 'IDN', 'PHL', 'VNM', 'KHM', 'MMR', 'LAO', 'MNG'],
    'South Asia': ['IND', 'BGD', 'PAK', 'LKA', 'NPL'],
    'Europe': ['POL', 'CZE', 'HUN', 'SVK', 'SVN', 'HRV', 'ROU', 'BGR', 'SRB', 'MKD', 'ALB', 'BIH',
               'GBR', 'FRA', 'DEU', 'ITA', 'ESP', 'PRT', 'GRC', 'NLD', 'BEL', 'AUT', 'CHE', 'SWE',
               'NOR', 'DNK', 'FIN', 'IRL', 'LUX', 'EST', 'LVA', 'LTU', 'CYP', 'MLT'],
    'MENA': ['SAU', 'ARE', 'KWT', 'QAT', 'BHR', 'OMN', 'IRN', 'IRQ', 'TUR', 'ISR', 'EGY', 'MAR', 'TUN', 'DZA', 'LBY', 'JOR', 'LBN'],
    'SSA': ['NGA', 'ZAF', 'KEN', 'GHA', 'ETH', 'TZA', 'UGA', 'CIV', 'CMR', 'SEN', 'ZMB', 'ZWE', 'MOZ', 'AGO', 'COD', 'BWA', 'MUS', 'NAM'],
    'Latin America': ['BRA', 'MEX', 'ARG', 'CHL', 'COL', 'PER', 'VEN', 'ECU', 'BOL', 'PRY', 'URY', 'CRI', 'PAN', 'DOM', 'GTM', 'HND', 'SLV', 'NIC', 'JAM', 'TTO'],
}

# Reverse map: ISO3 code -> region name
iso_to_region = {iso: region for region, isos in region_map.items() for iso in isos}

zone_entrants_hist['region'] = zone_entrants_hist['iso3'].map(iso_to_region).fillna('Other')

# East Asia dummy
zone_entrants_hist['east_asia'] = (zone_entrants_hist['region'] == 'East Asia').astype(float)

# Logit with East Asia dummy
logit_vars_region = ['Z_1_entry', 'gdp_pc_ppp_entry', 'east_asia']
ze_sample = zone_entrants_hist[['exited_above'] + logit_vars_region].dropna()

if len(ze_sample) >= 15:
    try:
        model_region = sm.Logit(ze_sample['exited_above'],
                                sm.add_constant(ze_sample[logit_vars_region])).fit(disp=0)
    except Exception as exc:
        # A dummy that perfectly predicts the outcome (or a singular design)
        # makes the fit raise; report it and continue with the next model.
        print(f"\nLogit with East Asia control failed (n={len(ze_sample)}): "
              f"{type(exc).__name__}: {exc}")
    else:
        print(f"\nLogit with East Asia control (n={len(ze_sample)}):")
        for var in logit_vars_region:
            p_val = model_region.pvalues[var]
            sig = '***' if p_val < 0.01 else '**' if p_val < 0.05 else '*' if p_val < 0.1 else ''
            print(f"  {var}: coef={model_region.params[var]:.4f}, p={p_val:.4f}{sig}")
        print(f"  pseudo-R²={model_region.prsquared:.3f}")

# Region dummies. South Asia, SSA and 'Other' get no dummy and so form the
# reference group. The column list is built explicitly rather than by scanning
# for a 'region_' prefix, so unrelated columns of the input file that happen
# to share the prefix cannot leak into the model.
region_dummies = []
for region in ['East Asia', 'Europe', 'Latin America', 'MENA']:
    dummy_col = f'region_{region.replace(" ", "_")}'
    zone_entrants_hist[dummy_col] = (zone_entrants_hist['region'] == region).astype(float)
    region_dummies.append(dummy_col)

logit_vars_full_region = ['Z_1_entry', 'gdp_pc_ppp_entry'] + region_dummies
ze_sample_full = zone_entrants_hist[['exited_above'] + logit_vars_full_region].dropna()

if len(ze_sample_full) >= 15:
    try:
        model_full_region = sm.Logit(ze_sample_full['exited_above'],
                                     sm.add_constant(ze_sample_full[logit_vars_full_region])).fit(disp=0)
    except Exception as exc:
        print("  Full region model failed (likely multicollinearity)")
        # Show the actual error instead of leaving the cause to guesswork.
        print(f"    {type(exc).__name__}: {exc}")
    else:
        print(f"\nLogit with all region dummies (n={len(ze_sample_full)}):")
        for var in logit_vars_full_region:
            p_val = model_full_region.pvalues[var]
            sig = '***' if p_val < 0.01 else '**' if p_val < 0.05 else '*' if p_val < 0.1 else ''
            print(f"  {var}: coef={model_full_region.params[var]:.4f}, p={p_val:.4f}{sig}")
        print(f"  pseudo-R²={model_full_region.prsquared:.3f}")

# ═══════════════════════════════════════════════════════════════════════
# 6e. Reverse causality: does crossing predict demographic change?
# ═══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("6e. REVERSE CAUSALITY: DOES CROSSING PREDICT ΔZ₁?")
print("=" * 70)

# Among zone entrants, does exiting above predict subsequent Z₁ change?
# If reverse causality: getting rich → lower fertility → higher Z₁
# We test: does Z₁ at entry predict crossing, controlling for subsequent Z₁ change?

rc_vars = ['Z_1_entry', 'gdp_pc_ppp_entry', 'Z_1_change_in_zone']

# zone_entrants_hist is a row subset of cross_df, so it already carries
# 'Z_1_change_in_zone'. Selecting it directly avoids a merge back onto
# cross_df by iso3, which would multiply rows if an iso3 ever appeared more
# than once there.
zone_with_dz1 = zone_entrants_hist[['iso3', 'exited_above'] + rc_vars].copy()

rc_sample = zone_with_dz1[['exited_above'] + rc_vars].dropna()

if len(rc_sample) >= 15:
    # Does ΔZ₁ during transit differ by outcome?
    c_dz1 = rc_sample[rc_sample['exited_above'] == 1]['Z_1_change_in_zone']
    nc_dz1 = rc_sample[rc_sample['exited_above'] == 0]['Z_1_change_in_zone']
    if len(c_dz1) >= 2 and len(nc_dz1) >= 2:
        # Welch t-test (unequal variances); needs at least 2 obs per group.
        t, p = stats.ttest_ind(c_dz1, nc_dz1, equal_var=False)
        print(f"ΔZ₁ during transit: crossers={c_dz1.mean():.3f}, non-crossers={nc_dz1.mean():.3f}, p={p:.4f}")
    else:
        print(f"ΔZ₁ during transit: t-test skipped "
              f"(crossers n={len(c_dz1)}, non-crossers n={len(nc_dz1)})")

    # Logit controlling for ΔZ₁
    try:
        model_rc = sm.Logit(rc_sample['exited_above'],
                            sm.add_constant(rc_sample[rc_vars])).fit(disp=0)
    except Exception as exc:
        # Keep the script alive so that 6f and the final table still run.
        print(f"\nLogit controlling for ΔZ₁ during transit failed (n={len(rc_sample)}): "
              f"{type(exc).__name__}: {exc}")
    else:
        print(f"\nLogit controlling for ΔZ₁ during transit (n={len(rc_sample)}):")
        for var in rc_vars:
            p_val = model_rc.pvalues[var]
            sig = '***' if p_val < 0.01 else '**' if p_val < 0.05 else '*' if p_val < 0.1 else ''
            print(f"  {var}: coef={model_rc.params[var]:.4f}, p={p_val:.4f}{sig}")
        # "Survives" means Z_1_entry stays significant at the 10% level.
        print(f"  Z₁_entry survives control for ΔZ₁: {'YES' if model_rc.pvalues['Z_1_entry'] < 0.1 else 'NO'}")
else:
    print(f"Reverse-causality check skipped: only {len(rc_sample)} "
          "complete observations (need at least 15)")

# ═══════════════════════════════════════════════════════════════════════
# 6f. Panel regression robustness: zone interaction with alt thresholds
# ═══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("6f. PANEL CA REGRESSIONS WITH ALTERNATIVE ZONE DEFINITIONS")
print("=" * 70)

controls = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'rgdp_growth', 'log_rel_opw', 'kaopen']

# Regressor order passed to PanelGLS; coefficient positions are looked up by
# name so they stay correct if this list is ever reordered.
xvars = ['Z_1', 'in_zone', 'Z_1_x_zone'] + controls
z1_idx = xvars.index('Z_1')
zone_idx = xvars.index('Z_1_x_zone')

# Only the columns the regression needs; avoids copying the whole panel for
# every threshold set.
panel_base = panel[['iso3', 'year', 'ca_gdp', 'gdp_pc_ppp', 'Z_1'] + controls]
gdp_pc = panel_base['gdp_pc_ppp']

for label, lower, upper in threshold_sets:
    # 1.0 inside [lower, upper], 0.0 outside, NaN where GDP is missing.
    # Without the NaN step, observations with unknown GDP would be coded as
    # "outside the zone" and silently kept in the regression sample.
    zone_flag = ((gdp_pc >= lower) & (gdp_pc <= upper)).astype(float)
    zone_flag[gdp_pc.isna()] = np.nan

    panel_temp = panel_base.assign(
        in_zone=zone_flag,
        Z_1_x_zone=panel_base['Z_1'] * zone_flag,
    )
    subset = panel_temp[['iso3', 'year', 'ca_gdp'] + xvars].dropna()

    if len(subset) < 100:
        print(f"  {label}: skipped, only {len(subset)} complete observations (need at least 100)")
        continue

    # PanelGLS is the project's estimator; the call below passes the outcome,
    # the regressors, then the country and year identifiers.
    gls = PanelGLS()
    gls.fit(subset['ca_gdp'], subset[xvars], subset['iso3'], subset['year'])

    z1_p = gls.pvalues[z1_idx]
    int_p = gls.pvalues[zone_idx]
    z1_sig = '***' if z1_p < 0.01 else '**' if z1_p < 0.05 else '*' if z1_p < 0.1 else ''
    int_sig = '***' if int_p < 0.01 else '**' if int_p < 0.05 else '*' if int_p < 0.1 else ''
    print(f"  {label}: Z₁={gls.beta[z1_idx]:.3f}{z1_sig}, Z₁×zone={gls.beta[zone_idx]:.3f}{int_sig}, "
          f"n={gls.n_obs}, R²={gls.r_squared:.3f}")

# ═══════════════════════════════════════════════════════════════════════
# Save robustness table
# ═══════════════════════════════════════════════════════════════════════
# Write the per-threshold results collected in section 6a. Create the output
# directory first so a missing folder cannot discard the whole run at the
# very last step.
robustness_table_path = Path(
    '/mnt/c/demographics_capital_flows/development_threshold/output/tables/table15_robustness_thresholds.csv')
robustness_table_path.parent.mkdir(parents=True, exist_ok=True)
pd.DataFrame(robustness_results).to_csv(robustness_table_path, index=False)

print("\n✓ Phase 6 complete.")
